"""
PyTorch implementation of "A simple neural network module for relational reasoning".
Trains a Transformer or Compositional model (see ``--model``) on Sort-of-CLEVR.
Code is based on pytorch/examples/mnist (https://github.com/pytorch/examples/tree/master/mnist)
"""
from __future__ import print_function
import argparse
import os

import pickle
import random
import numpy as np
import csv

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable

from model import *
import wandb

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Relational-Network Sort-of-CLEVR Example')
parser.add_argument('--model', type=str, choices=['Transformer', 'Compositional'], default='Transformer',
                    help='model architecture to train (default: Transformer)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                    help='learning rate (default: 0.0001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--relation-type', type=str, default='binary',
                    help='what kind of relations to learn. options: binary, ternary (default: binary)')
parser.add_argument('--transformer-dim', type=int, default=128)
parser.add_argument('--qk-dim', type=int, default=16)
parser.add_argument('--iterations', default=1, type=int,
                    help='Number of Transformer Iterations to use in the relational base')
parser.add_argument('--n-heads', type=int, default=4)
parser.add_argument('--n-rules', type=int, default=1)
parser.add_argument('--dot', action='store_true', default=False)
# Run name: used for the wandb run and as the logs/<name>/ output folder.
parser.add_argument('--name', type=str, default='Default')
args = parser.parse_args()

# Use the GPU only when it was not disabled with --no-cuda and one is available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Hyperparameters recorded with the wandb run (key order is preserved).
_config_items = (
    ("Transformer Dimension", args.transformer_dim),
    ("Number of Heads", args.n_heads),
    ("Number of Rules", args.n_rules),
    ("Iterations", args.iterations),
    ("Model", args.model),
    ("Seed", args.seed),
    ("qk-dim", args.qk_dim),
    ("dot", args.dot),
    ("lr", args.lr),
)
config = {key: value for key, value in _config_items}

_wandb_settings = wandb.Settings(start_method='fork')
wandb.init(
    settings=_wandb_settings,
    project="Sort-of-CLEVR-HeadAblation",
    config=config,
    name=args.name,
)

print(args)


def set_seed(seed):
    """Seed every random number generator this script uses.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), and puts cuDNN into deterministic, non-autotuned mode.

    Only the ``seed`` argument is used. The previous version also re-seeded
    CUDA from the global ``args.seed``, which silently ignored the parameter
    whenever the two differed.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers all visible GPUs; calling it without CUDA is harmless.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(args.seed)

# Per-run output layout: logs/<name>/tensorboard/ and logs/<name>/checkpoints/.
folder_name = f'logs/{args.name}'
tensorboard_dir = f'{folder_name}/tensorboard/'
model_dir = f'{folder_name}/checkpoints/'

# exist_ok=True makes this idempotent and also repairs a partially created run
# folder. The former `if not isdir(folder_name)` guard skipped creating the
# subdirectories whenever the top-level folder already existed.
os.makedirs(tensorboard_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

summary_writer = SummaryWriter(tensorboard_dir)

model = Model(args)

print(model)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of Parameters: ", n_params)

wandb.watch(model)

bs = args.batch_size
# Reusable input buffers. tensor_data() resizes them in place and copies each
# batch into them, so the initial (uninitialized) contents and shapes only
# matter until the first batch is loaded.
# Plain tensors replace the deprecated torch.autograd.Variable wrappers and
# legacy torch.FloatTensor/LongTensor size constructors; neither needs grad.
input_img = torch.empty(bs, 3, 75, 75, dtype=torch.float32)
input_qst = torch.empty(bs, 18, dtype=torch.float32)
label = torch.empty(bs, dtype=torch.long)

if args.cuda:
    model.cuda()
    input_img = input_img.cuda()
    input_qst = input_qst.cuda()
    label = label.cuda()


def tensor_data(data, i):
    img = torch.from_numpy(np.asarray(data[0][bs * i:bs * (i + 1)]))
    qst = torch.from_numpy(np.asarray(data[1][bs * i:bs * (i + 1)]))
    ans = torch.from_numpy(np.asarray(data[2][bs * i:bs * (i + 1)]))

    input_img.data.resize_(img.size()).copy_(img)
    input_qst.data.resize_(qst.size()).copy_(qst)
    label.data.resize_(ans.size()).copy_(ans)


def cvt_data_axis(data):
    """Transpose a sequence of ``(img, qst, ans)`` samples into parallel lists.

    Returns:
        A tuple ``(imgs, qsts, anss)`` whose k-th list holds field k of every
        sample, in the original order.
    """
    return tuple([sample[field] for sample in data] for field in range(3))


def train(epoch, rel):
    """Run one training epoch over ``rel`` and return the mean batch accuracy.

    ``rel`` is shuffled in place, transposed into (images, questions, answers)
    columns and consumed in full batches of ``bs`` samples; a trailing partial
    batch is dropped. Each batch is copied into the shared input buffers by
    ``tensor_data`` and passed to ``model.train_``. Mean accuracy and loss are
    logged to TensorBoard under ``Accuracy/train`` and ``Loss/train``.

    Args:
        epoch: Epoch number, used for console output and as the TensorBoard step.
        rel: List of ``(img, qst, ans)`` samples. Shuffled in place.

    Returns:
        Mean of the per-batch accuracies returned by ``model.train_``.

    Raises:
        ValueError: if ``rel`` holds fewer than ``bs`` samples (no full batch).
    """
    model.train()

    random.shuffle(rel)
    rel = cvt_data_axis(rel)

    n_samples = len(rel[0])
    last = n_samples // bs
    if last == 0:
        # Without a single batch the averages below would divide by zero.
        raise ValueError(
            f'train(): {n_samples} samples is fewer than one batch of {bs}')

    acc_rels = []
    l_binary = []

    for batch_idx in range(last):
        tensor_data(rel, batch_idx)
        accuracy_rel, loss_binary = model.train_(input_img, input_qst, label)
        batch_acc = accuracy_rel.item()
        acc_rels.append(batch_acc)
        l_binary.append(loss_binary.item())

        if batch_idx % args.log_interval == 0:
            # Each sample is one question, so progress is counted in samples;
            # the former "* 2" factors doubled both the position and the total.
            print('Train Epoch: {} [{}/{} ({:.0f}%)] '
                  'Relations accuracy: {:.0f}%'.format(
                      epoch,
                      batch_idx * bs,
                      n_samples,
                      100. * batch_idx * bs / n_samples,
                      batch_acc))

    avg_acc_binary = sum(acc_rels) / len(acc_rels)
    avg_loss_binary = sum(l_binary) / len(l_binary)

    summary_writer.add_scalars('Accuracy/train', {
        'binary': avg_acc_binary,
    }, epoch)
    summary_writer.add_scalars('Loss/train', {
        'binary': avg_loss_binary,
    }, epoch)

    return avg_acc_binary


def test(epoch, rel, split='Test'):
    """Evaluate the model on ``rel`` and return the mean batch accuracy.

    Samples are consumed in full batches of ``bs`` (a trailing partial batch is
    skipped) and passed to ``model.test_``. Mean accuracy and loss are logged to
    TensorBoard under ``Accuracy/<split>`` and ``Loss/<split>``.

    Args:
        epoch: Epoch number, used as the TensorBoard step.
        rel: List of ``(img, qst, ans)`` samples (not modified).
        split: Split label used in console output and TensorBoard tags.

    Returns:
        Mean of the per-batch accuracies returned by ``model.test_``.

    Raises:
        ValueError: if ``rel`` holds fewer than ``bs`` samples (no full batch).
    """
    model.eval()
    # NOTE(review): no torch.no_grad() here; confirm model.test_ disables
    # autograd itself, otherwise evaluation keeps unnecessary graphs alive.

    rel = cvt_data_axis(rel)
    n_samples = len(rel[0])
    n_batches = n_samples // bs
    if n_batches == 0:
        # Without a single batch the averages below would divide by zero.
        raise ValueError(
            f'test({split}): {n_samples} samples is fewer than one batch of {bs}')

    accuracy_rels = []
    loss_binary = []

    for batch_idx in range(n_batches):
        tensor_data(rel, batch_idx)
        acc_bin, l_bin = model.test_(input_img, input_qst, label)
        accuracy_rels.append(acc_bin.item())
        loss_binary.append(l_bin.item())

    accuracy_rel = sum(accuracy_rels) / len(accuracy_rels)
    mean_loss = sum(loss_binary) / len(loss_binary)
    print('{} set: Binary accuracy: {:.0f}%'.format(
        split, accuracy_rel))

    summary_writer.add_scalars(f'Accuracy/{split}', {
        'binary': accuracy_rel,
    }, epoch)
    # Tag the loss by split like the accuracy. Previously both the Val and the
    # Test losses were written to the single 'Loss/test' tag at the same step,
    # which made the two splits indistinguishable.
    summary_writer.add_scalars(f'Loss/{split}', {
        'binary': mean_loss,
    }, epoch)

    return accuracy_rel


def _flatten_relations(entries):
    """Expand ``(img, (questions, answers))`` entries into ``(img, qst, ans)`` samples."""
    samples = []
    for img, relations in entries:
        # Swaps axes 0 and 2 (presumably HWC -> channel-first; note this also
        # transposes the two spatial axes). TODO confirm the stored layout.
        img = np.swapaxes(img, 0, 2)
        samples.extend((img, qst, ans) for qst, ans in zip(relations[0], relations[1]))
    return samples


def load_data(data_dir='/miniscratch/mittalsa/data/data'):
    """Load Sort-of-CLEVR and flatten it into one sample per question.

    Reads ``sort-of-clevr.pickle`` from ``data_dir``. The pickle must hold a
    ``(train, val, test)`` triple. Train/val entries are ``(img, relations)``
    pairs and test entries are ``(img, ternary, relations, norelations)``
    4-tuples. Only ``relations``, a ``(questions, answers)`` pair, is used from
    each entry. Every image has axes 0 and 2 swapped and is paired with each of
    its questions.

    Args:
        data_dir: Directory containing ``sort-of-clevr.pickle``. Defaults to the
            previously hard-coded cluster path, so existing calls are unchanged.
            (An older copy of the data lived at /miniscratch/mittalsa/data/old_data.)

    Returns:
        ``(rel_train, rel_val, rel_test)``, each a list of ``(img, qst, ans)``.
    """
    print('loading data...')
    filename = os.path.join(data_dir, 'sort-of-clevr.pickle')
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as f:
        train_datasets, val_datasets, test_datasets = pickle.load(f)

    print('processing data...')

    rel_train = _flatten_relations(train_datasets)
    rel_val = _flatten_relations(val_datasets)
    # Test entries carry extra ternary / non-relational questions that are unused here.
    rel_test = _flatten_relations(
        (img, relations) for img, _ternary, relations, _norelations in test_datasets)

    return rel_train, rel_val, rel_test


rel_train, rel_val, rel_test = load_data()

# Model selection on the validation split: best validation accuracy seen so far
# and the test accuracy recorded at that same epoch.
best_val_rel = float('-inf')
opt_rel = float('-inf')

# newline='' is what the csv module requires for files it writes
# (avoids blank lines between rows on platforms with \r\n line endings).
with open(f'./{folder_name}/log.csv', 'w', newline='') as log_file:
    csv_writer = csv.writer(log_file, delimiter=',')
    csv_writer.writerow(['epoch',
                         'train_acc_rel',
                         'val_acc_rel',
                         'best_val_acc_rel',
                         'test_acc_rel',
                         'optimal_test_acc_rel'])

    print(f"Training {args.model} model...")
    for epoch in range(1, args.epochs + 1):
        # train() and test() each return a single float (mean binary accuracy).
        # The previous three-way tuple unpacking raised TypeError on epoch 1.
        train_acc_binary = train(epoch, rel_train)
        print()

        val_acc_binary = test(epoch, rel_val, split='Val')
        test_acc_binary = test(epoch, rel_test, split='Test')

        if val_acc_binary > best_val_rel:
            best_val_rel = val_acc_binary
            opt_rel = test_acc_binary

        # Named `metrics` rather than `dict` so the builtin is not shadowed.
        metrics = {
            "Binary Train Accuracy": train_acc_binary,
            "Binary Val Accuracy": val_acc_binary,
            "Best Binary Val Accuracy": best_val_rel,
            "Binary Test Accuracy": test_acc_binary,
            "Optimal Binary Test Accuracy": opt_rel,
        }
        print()
        wandb.log(metrics, step=epoch)

        csv_writer.writerow([epoch,
                             train_acc_binary,
                             val_acc_binary,
                             best_val_rel,
                             test_acc_binary,
                             opt_rel])
        # Flush every epoch so the CSV can be inspected while training runs
        # and is not lost if the process is killed.
        log_file.flush()

        model.save_model(epoch, model_dir)
